{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "toc": true
   },
   "source": [
    "<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n",
    "<div class=\"toc\"><ul class=\"toc-item\"><li><span><a href=\"#Load-Network\" data-toc-modified-id=\"Load-Network-1\"><span class=\"toc-item-num\">1&nbsp;&nbsp;</span>Load Network</a></span></li><li><span><a href=\"#Style-Mixing\" data-toc-modified-id=\"Style-Mixing-2\"><span class=\"toc-item-num\">2&nbsp;&nbsp;</span>Style Mixing</a></span></li><li><span><a href=\"#Latents-Transition/Morphing\" data-toc-modified-id=\"Latents-Transition/Morphing-3\"><span class=\"toc-item-num\">3&nbsp;&nbsp;</span>Latents Transition/Morphing</a></span></li><li><span><a href=\"#Explore-PSI\" data-toc-modified-id=\"Explore-PSI-4\"><span class=\"toc-item-num\">4&nbsp;&nbsp;</span>Explore PSI</a></span></li><li><span><a href=\"#Explore-Latents-Indexes\" data-toc-modified-id=\"Explore-Latents-Indexes-5\"><span class=\"toc-item-num\">5&nbsp;&nbsp;</span>Explore Latents Indexes</a></span></li></ul></div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Playground for experiments with StyleGAN2 latents.\n",
    "Includes interactive style mixing, latent interpolation (morphing), and latent tweaking."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import sys\n",
    "import os\n",
    "from datetime import datetime\n",
    "from tqdm import tqdm\n",
    "import imageio\n",
    "\n",
    "# ffmpeg installation location, for creating videos\n",
    "plt.rcParams['animation.ffmpeg_path'] = str(Path.home() / \"Documents/dev_tools/ffmpeg-20190623-ffa64a4-win64-static/bin/ffmpeg.exe\")\n",
    "import ipywidgets as widgets\n",
    "from ipywidgets import interact, interact_manual\n",
    "from IPython.display import display\n",
    "from ipywidgets import Button\n",
    "\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# StyleGAN Utils\n",
    "from stylegan_utils import load_network, gen_image_fun, synth_image_fun, create_video\n",
    "\n",
    "# StyleGAN2 Repo\n",
    "sys.path.append(os.path.join(os.pardir, 'stylegan2encoder'))\n",
    "\n",
    "import run_projector\n",
    "import projector\n",
    "import training.dataset\n",
    "import training.misc\n",
    "\n",
    "# Data Science Utils\n",
    "sys.path.append(os.path.join(os.pardir, 'data-science-learning'))\n",
    "\n",
    "from ds_utils import generative_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_dir = Path.home() / 'Documents/generated_data/stylegan'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Models live under the user's home folder, like the other paths of this notebook\n",
    "MODELS_DIR = Path.home() / 'Documents/models/stylegan2'\n",
    "MODEL_NAME = 'original_ffhq'\n",
    "SNAPSHOT_NAME = 'stylegan2-ffhq-config-f'\n",
    "\n",
    "# the snapshot is stored as <MODELS_DIR>/<MODEL_NAME>/<SNAPSHOT_NAME>.pkl\n",
    "Gs, Gs_kwargs, noise_vars = load_network(str(MODELS_DIR / MODEL_NAME / (SNAPSHOT_NAME + '.pkl')))\n",
    "\n",
    "# sizes read from the input/output shapes declared by the network\n",
    "Z_SIZE = Gs.input_shape[1]\n",
    "IMG_SIZE = Gs.output_shape[2:]\n",
    "IMG_SIZE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Style Mixing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# used when saving the currently displayed image\n",
    "current_displayed_latents = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_latents(latents):\n",
    "    # If not already numpy array, load the latents\n",
    "    if type(latents) is not np.ndarray:\n",
    "        latents = np.load(latents)\n",
    "    \n",
    "    # TMP fix for when saved latens as [1, 16, 512]\n",
    "    if len(latents.shape) == 3:\n",
    "        assert latents.shape[0] == 1\n",
    "        latents = latents[0]\n",
    "    \n",
    "    return latents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_mix(latents_1, latents_2, style_layers_idxs, synth_image_fun, alpha=1):\n",
    "    \"\"\"Blend selected entries of two latents and synthesise the result.\n",
    "\n",
    "    Both inputs go through ``load_latents`` and must end up with the same shape.\n",
    "    The mix starts as a copy of ``latents_2``; the entries selected by\n",
    "    ``style_layers_idxs`` become ``latents_1 * alpha + latents_2 * (1 - alpha)``.\n",
    "    The mixed latents are also stored in the global ``current_displayed_latents``.\n",
    "\n",
    "    Returns whatever ``synth_image_fun`` returns for the mix, which is passed\n",
    "    with an extra leading axis.\n",
    "    \"\"\"\n",
    "    # the mix is published globally so that the widget can export it later\n",
    "    global current_displayed_latents\n",
    "\n",
    "    source = load_latents(latents_1)\n",
    "    target = load_latents(latents_2)\n",
    "    assert source.shape == target.shape\n",
    "\n",
    "    # crossover from latents_1 into a copy of latents_2\n",
    "    mixed = target.copy()\n",
    "    source_part = source[style_layers_idxs] * alpha\n",
    "    target_part = mixed[style_layers_idxs] * (1-alpha)\n",
    "    mixed[style_layers_idxs] = source_part + target_part\n",
    "\n",
    "    current_displayed_latents = mixed\n",
    "\n",
    "    # generate\n",
    "    return synth_image_fun(mixed[np.newaxis, :, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup plot image\n",
    "button = Button(description=\"Savefig\")\n",
    "\n",
    "dpi = 100\n",
    "# figsize is given in inches, so both sides are divided by the dpi to obtain\n",
    "# a canvas that measures IMG_SIZE pixels\n",
    "fig, ax = plt.subplots(dpi=dpi, figsize=(IMG_SIZE[0]/dpi, IMG_SIZE[1]/dpi))\n",
    "fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0, wspace=0)\n",
    "ax.axis('off')\n",
    "# placeholder image; z is drawn with randn like the other z vectors of this notebook\n",
    "im = ax.imshow(gen_image_fun(Gs, np.random.randn(1, Z_SIZE), noise_vars, Gs_kwargs))\n",
    "\n",
    "# prevent any output for this cell\n",
    "plt.close(fig)\n",
    "\n",
    "\n",
    "def on_button_clicked(b):\n",
    "    \"\"\"Save the current figure and, when available, the latents behind it.\n",
    "\n",
    "    Both files are written to the 'picked' folder of the current snapshot and\n",
    "    share a timestamp as their name. ``b`` is the clicked button (unused).\n",
    "    \"\"\"\n",
    "    dest_dir = res_dir / 'projection' / MODEL_NAME / SNAPSHOT_NAME / \"picked\"\n",
    "    # the folder may not exist yet the first time something is picked\n",
    "    dest_dir.mkdir(parents=True, exist_ok=True)\n",
    "    timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "\n",
    "    fig.savefig(dest_dir / (timestamp + '.png'), bbox_inches='tight')\n",
    "\n",
    "    # no latents are recorded until a mix has been generated\n",
    "    if current_displayed_latents is not None:\n",
    "        np.save(dest_dir / (timestamp + '.npy'), current_displayed_latents)\n",
    "\n",
    "\n",
    "button.on_click(on_button_clicked)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = res_dir / 'projection' / MODEL_NAME / SNAPSHOT_NAME / '20200208_181509'\n",
    "# one entry per sub-folder of data_dir, leaving out 'tfrecords'; filtering (rather\n",
    "# than list.remove) also works when that folder is absent, and sorting keeps the\n",
    "# order of the widget options stable across runs\n",
    "entries = sorted(p.name for p in data_dir.glob(\"*\")\n",
    "                 if p.is_dir() and p.name != 'tfrecords')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "display(button)\n",
    "@interact\n",
    "def i_style_mixing(entry1 = entries, entry2 = entries,\n",
    "                   from_layer = np.arange(0, 18), to_layer = np.arange(0, 18),\n",
    "                   alpha = (-0.5, 1.5)):\n",
    "    \"\"\"Mix the latents of two entries over a range of layers and show the result.\n",
    "\n",
    "    Layers ``from_layer`` to ``to_layer`` (both included) are blended from\n",
    "    ``entry1`` into ``entry2`` with weight ``alpha``; see ``generate_mix``.\n",
    "    \"\"\"\n",
    "    if from_layer > to_layer:\n",
    "        print('from_layer must not be greater than to_layer')\n",
    "        return\n",
    "\n",
    "    # NOTE(review): assumes every entry folder of data_dir holds its projected\n",
    "    # latents as image_latents1000.npy -- confirm against the projection output\n",
    "    latents_1 = data_dir / entry1 / \"image_latents1000.npy\"\n",
    "    latents_2 = data_dir / entry2 / \"image_latents1000.npy\"\n",
    "\n",
    "    # + 1 because np.arange excludes its stop value, which would otherwise make\n",
    "    # the last selectable layer impossible to mix\n",
    "    gen_image = generate_mix(latents_1, latents_2,\n",
    "                             style_layers_idxs=np.arange(from_layer, to_layer + 1),\n",
    "                             synth_image_fun=lambda dlatens : synth_image_fun(Gs, dlatens, randomize_noise=True),\n",
    "                             alpha=alpha)\n",
    "    im.set_data(gen_image)\n",
    "    display(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Latents Transition/Morphing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#PLOT_IMG_SHAPE = (512, 512, 3)\n",
    "PLOT_IMG_SHAPE = (IMG_SIZE[0], IMG_SIZE[1], 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "render_dir = res_dir / MODEL_NAME / SNAPSHOT_NAME / \"explore_latent\"\n",
    "\n",
    "nb_samples = 2\n",
    "nb_transition_frames = 450\n",
    "nb_frames = min(450, (nb_samples-1)*nb_transition_frames)\n",
    "\n",
    "psi=1\n",
    "\n",
    "# NOTE(review): the animation helpers are assumed to be provided by the imported\n",
    "# generative_utils module -- confirm. The gen_image_fun arguments follow the\n",
    "# order used for the placeholder image of the Style Mixing section -- verify\n",
    "# against stylegan_utils.\n",
    "\n",
    "# run animation, once per random set of z vectors\n",
    "for run_idx in range(0, 2):\n",
    "    # setup the passed latents\n",
    "    z_s = np.random.randn(nb_samples, Z_SIZE)\n",
    "    passed_latents=z_s\n",
    "\n",
    "    generative_utils.animate_latent_transition(\n",
    "        latent_vectors=passed_latents,\n",
    "        gen_image_fun=lambda latents : gen_image_fun(Gs, latents, noise_vars, Gs_kwargs, truncation_psi=psi),\n",
    "        gen_latent_fun=lambda z_s, i: generative_utils.gen_latent_linear(passed_latents, i, nb_transition_frames),\n",
    "        img_size=PLOT_IMG_SHAPE,\n",
    "        nb_frames=nb_frames,\n",
    "        render_dir=render_dir / \"transitions\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Explore PSI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#PLOT_IMG_SHAPE = (512, 512, 3)\n",
    "PLOT_IMG_SHAPE = (IMG_SIZE[0], IMG_SIZE[1], 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "render_dir = res_dir / MODEL_NAME / SNAPSHOT_NAME / 'explore_latent'\n",
    "\n",
    "nb_samples = 20\n",
    "nb_transition_frames = 24\n",
    "nb_frames = min(450, (nb_samples-1)*nb_transition_frames)\n",
    "\n",
    "# setup the passed latents: the same z vectors are reused for every psi\n",
    "z_s = np.random.randn(nb_samples, Z_SIZE)\n",
    "passed_latents = z_s\n",
    "\n",
    "# NOTE(review): the animation helpers are assumed to be provided by the imported\n",
    "# generative_utils module -- confirm. The gen_image_fun arguments follow the\n",
    "# order used for the placeholder image of the Style Mixing section -- verify\n",
    "# against stylegan_utils.\n",
    "\n",
    "# run animation, one render per truncation psi\n",
    "for psi in np.linspace(-0.5, 1.5, 9):\n",
    "    generative_utils.animate_latent_transition(\n",
    "        latent_vectors=passed_latents,\n",
    "        # psi is bound as a default so the lambda keeps the value of this iteration\n",
    "        gen_image_fun=lambda latents, psi=psi: gen_image_fun(Gs, latents, noise_vars, Gs_kwargs, truncation_psi=psi),\n",
    "        gen_latent_fun=lambda z_s, i: generative_utils.gen_latent_linear(passed_latents, i, nb_transition_frames),\n",
    "        img_size=PLOT_IMG_SHAPE,\n",
    "        nb_frames=nb_frames,\n",
    "        render_dir=render_dir / 'psi',\n",
    "        file_prefix='psi{}'.format(str(psi).replace('.', '_')[:5]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Explore Latents Indexes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#PLOT_IMG_SHAPE = (512, 512, 3)\n",
    "PLOT_IMG_SHAPE = (IMG_SIZE[0], IMG_SIZE[1], 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "render_dir = res_dir / MODEL_NAME / SNAPSHOT_NAME / \"explore_latent\"\n",
    "\n",
    "nb_transition_frames = 48\n",
    "\n",
    "# random z vector, mapped to the dlatents that are tweaked below\n",
    "z_start = np.random.randn(1, Z_SIZE)\n",
    "# NOTE(review): `dlatents` was used without being defined. It is computed here\n",
    "# with the mapping network, assuming the result has one row per style, i.e. a\n",
    "# shape of [1, nb_styles, Z_SIZE] -- confirm against the network in use\n",
    "dlatents = Gs.components.mapping.run(z_start, None)\n",
    "nb_styles = dlatents.shape[1]\n",
    "# one random vector per frame, with a mean that moves from -1 to 1 over the transition\n",
    "stylelatent_vals= np.random.randn(nb_transition_frames, Z_SIZE) + np.linspace(-1., 1., nb_transition_frames)[:, np.newaxis]\n",
    "\n",
    "# NOTE(review): the animation helpers are assumed to be provided by the imported\n",
    "# generative_utils module, and synth_image_fun is called with Gs first as in the\n",
    "# Style Mixing section -- confirm both\n",
    "for z_idx in range(nb_styles):\n",
    "    generative_utils.animate_latent_transition(\n",
    "        latent_vectors=dlatents[0],\n",
    "        # noise is not randomized so that frames differ only through the latents\n",
    "        gen_image_fun=lambda latents: synth_image_fun(Gs, latents, randomize_noise=False),\n",
    "        # z_idx is bound as a default so the lambda keeps the value of this iteration\n",
    "        gen_latent_fun=lambda z_s, i, z_idx=z_idx: generative_utils.gen_latent_style_idx(dlatents[0], i, z_idx, stylelatent_vals),\n",
    "        img_size=PLOT_IMG_SHAPE,\n",
    "        nb_frames=nb_transition_frames,\n",
    "        render_dir=render_dir / 'latent_indexes',\n",
    "        # a prefix per style index keeps the renders of this loop apart\n",
    "        file_prefix='style{}'.format(z_idx))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "data-science",
   "language": "python",
   "name": "data-science"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "notify_time": "30",
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
